{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fb2a049",
   "metadata": {
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24af9556",
   "metadata": {
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [],
   "source": [
    "# flake8: noqa"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70ac1fe8",
   "metadata": {},
   "source": [
    "# Model selection and serving with Ray Tune and Ray Serve\n",
    "\n",
    "```{image} /images/serve.svg\n",
    ":align: center\n",
    "```\n",
    "\n",
    "```{contents}\n",
    ":backlinks: none\n",
    ":local: true\n",
    "```\n",
    "\n",
    "This tutorial will show you an end-to-end example how to train a\n",
    "model using Ray Tune on incrementally arriving data and deploy\n",
    "the model using Ray Serve.\n",
    "\n",
    "A machine learning workflow can be quite simple: You decide on\n",
    "the objective you're trying to solve, collect and annotate the\n",
    "data, and build a model to hopefully solve your problem. But\n",
    "usually the work is not over yet. First, you would likely continue\n",
    "to do some hyperparameter optimization to obtain the best possible\n",
    "model (called *model selection*). Second, your trained model\n",
    "somehow has to be moved to production - in other words, users\n",
    "or services should be enabled to use your model to actually make\n",
    "predictions. This part is called *model serving*.\n",
    "\n",
    "Fortunately, Ray includes two libraries that help you with these\n",
    "two steps: Ray Tune and Ray Serve. And even more, they compliment\n",
    "each other nicely. Most notably, both are able to scale up your\n",
    "workloads easily - so both your model training and serving benefit\n",
    "from additional resources and can adapt to your environment. If you\n",
    "need to train on more data or have more hyperparameters to tune,\n",
    "Ray Tune can leverage your whole cluster for training. If you have\n",
    "many users doing inference on your served models, Ray Serve can\n",
    "automatically distribute the inference to multiple nodes.\n",
    "\n",
    "This tutorial will show you an end-to-end example how to train a MNIST\n",
    "image classifier on incrementally arriving data and automatically\n",
    "serve an updated model on a HTTP endpoint.\n",
    "\n",
    "By the end of this tutorial you will be able to\n",
    "\n",
    "1. Do hyperparameter optimization on a simple MNIST classifier\n",
    "2. Continue to train this classifier from an existing model with\n",
    "   newly arriving data\n",
    "3. Automatically create and serve data deployments with Ray Serve\n",
    "\n",
    "## Roadmap and desired functionality\n",
    "\n",
    "The general idea of this example is that we simulate newly arriving\n",
    "data each day. So at day 0 we might have some initial data available\n",
    "already, but at each day, new data arrives.\n",
    "\n",
    "Our approach here is that we offer two ways to train: From scratch and\n",
    "from an existing model. Maybe you would like to train and select models\n",
    "from scratch each week with all data available until then, e.g. each\n",
    "Sunday, like this:\n",
    "\n",
    "```{code-block} bash\n",
    "# Train with all data available at day 0\n",
    "python tune-serve-integration-mnist.py --from_scratch --day 0\n",
    "```\n",
    "\n",
    "During the other days you might want to improve your model, but\n",
    "not train everything from scratch, saving some cluster resources.\n",
    "\n",
    "```{code-block} bash\n",
    "# Train with data arriving between day 0 and day 1\n",
    "python tune-serve-integration-mnist.py --from_existing --day 1\n",
    "# Train with incremental data on the other days, too\n",
    "python tune-serve-integration-mnist.py --from_existing --day 2\n",
    "python tune-serve-integration-mnist.py --from_existing --day 3\n",
    "python tune-serve-integration-mnist.py --from_existing --day 4\n",
    "python tune-serve-integration-mnist.py --from_existing --day 5\n",
    "python tune-serve-integration-mnist.py --from_existing --day 6\n",
    "# Retrain from scratch every 7th day:\n",
    "python tune-serve-integration-mnist.py --from_scratch --day 7\n",
    "```\n",
    "\n",
    "This example will support both modes. After each model selection run,\n",
    "we will tell Ray Serve to serve an updated model. We also include a\n",
    "small utility to query our served model to see if it works as it should.\n",
    "\n",
    "```{code-block} bash\n",
    "$ python tune-serve-integration-mnist.py --query 6\n",
    "Querying model with example #6. Label = 1, Response = 1, Correct = True\n",
    "```\n",
    "\n",
    "\n",
    "## Imports\n",
    "\n",
    "Let's start with our dependencies. Most of these should be familiar\n",
    "if you worked with PyTorch before. The most notable import for Ray\n",
    "is the ``from ray import tune, serve`` import statement - which\n",
    "includes almost all the things we need from the Ray side."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0376c9c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import json\n",
    "import os\n",
    "import shutil\n",
    "import sys\n",
    "from functools import partial\n",
    "from math import ceil\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import ray\n",
    "from ray import tune, serve\n",
    "from ray.serve.exceptions import RayServeException\n",
    "from ray.tune import CLIReporter\n",
    "from ray.tune.schedulers import ASHAScheduler\n",
    "\n",
    "from torch.utils.data import random_split, Subset\n",
    "from torchvision.datasets import MNIST\n",
    "from torchvision.transforms import transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58eaafa1",
   "metadata": {},
   "source": [
    "## Data interface\n",
    "\n",
    "Let's start with a simulated data interface. This class acts as the\n",
    "interface between your training code and your database. We simulate\n",
    "that new data arrives each day with a ``day`` parameter. So, calling\n",
    "``get_data(day=3)`` would return all data we received until day 3.\n",
    "We also implement an incremental data method, so calling\n",
    "``get_incremental_data(day=3)`` would return all data collected\n",
    "between day 2 and day 3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcf4de94",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MNISTDataInterface(object):\n",
    "    \"\"\"Data interface. Simulates that new data arrives every day.\"\"\"\n",
    "\n",
    "    def __init__(self, data_dir, max_days=10):\n",
    "        self.data_dir = data_dir\n",
    "        self.max_days = max_days\n",
    "\n",
    "        transform = transforms.Compose(\n",
    "            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n",
    "        )\n",
    "        self.dataset = MNIST(\n",
    "            self.data_dir, train=True, download=True, transform=transform\n",
    "        )\n",
    "\n",
    "    def _get_day_slice(self, day=0):\n",
    "        if day < 0:\n",
    "            return 0\n",
    "        n = len(self.dataset)\n",
    "        # Start with 30% of the data, get more data each day\n",
    "        return min(n, ceil(n * (0.3 + 0.7 * day / self.max_days)))\n",
    "\n",
    "    def get_data(self, day=0):\n",
    "        \"\"\"Get complete normalized train and validation data to date.\"\"\"\n",
    "        end = self._get_day_slice(day)\n",
    "\n",
    "        available_data = Subset(self.dataset, list(range(end)))\n",
    "        train_n = int(0.8 * end)  # 80% train data, 20% validation data\n",
    "\n",
    "        return random_split(available_data, [train_n, end - train_n])\n",
    "\n",
    "    def get_incremental_data(self, day=0):\n",
    "        \"\"\"Get next normalized train and validation data day slice.\"\"\"\n",
    "        start = self._get_day_slice(day - 1)\n",
    "        end = self._get_day_slice(day)\n",
    "\n",
    "        available_data = Subset(self.dataset, list(range(start, end)))\n",
    "        train_n = int(0.8 * (end - start))  # 80% train data, 20% validation data\n",
    "\n",
    "        return random_split(available_data, [train_n, end - start - train_n])"
   ]
  },
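  {
   "cell_type": "markdown",
   "id": "e1a7c301",
   "metadata": {},
   "source": [
    "To see how much data each simulated day yields without\n",
    "downloading MNIST, you can evaluate the slice formula from ``_get_day_slice``\n",
    "on its own. The sketch below is a standalone copy of that\n",
    "formula. ``day_slice`` is a hypothetical helper name, and ``n = 60000``\n",
    "assumes the size of the MNIST training split.\n",
    "\n",
    "```python\n",
    "from math import ceil\n",
    "\n",
    "\n",
    "def day_slice(n, day, max_days=10):\n",
    "    \"\"\"Standalone copy of MNISTDataInterface._get_day_slice (hypothetical helper).\"\"\"\n",
    "    if day < 0:\n",
    "        return 0\n",
    "    # Start with 30% of the data, get more data each day\n",
    "    return min(n, ceil(n * (0.3 + 0.7 * day / max_days)))\n",
    "\n",
    "\n",
    "n = 60000  # size of the MNIST training split\n",
    "for day in range(11):\n",
    "    start, end = day_slice(n, day - 1), day_slice(n, day)\n",
    "    # get_data(day) uses [0, end); get_incremental_data(day) uses [start, end)\n",
    "    print(f\"day {day:2d}: total={end:5d}, incremental={end - start:5d}\")\n",
    "```\n",
    "\n",
    "Day 0 starts with 18,000 examples and day 10 reaches all 60,000. Each\n",
    "incremental window starts where the previous day's window ended, so\n",
    "the incremental chunks of days 0 to 10 cover the dataset exactly once."
   ]
  },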
  {
   "cell_type": "markdown",
   "id": "13612bb2",
   "metadata": {},
   "source": [
    "## PyTorch neural network classifier\n",
    "\n",
    "Next, we will introduce our PyTorch neural network model and the\n",
    "train and test function. These are adapted directly from\n",
    "our {doc}`PyTorch MNIST example </tune/examples/includes/mnist_pytorch>`.\n",
    "We only introduced an additional neural network layer with a configurable\n",
    "layer size. This is not strictly needed for learning good performance on\n",
    "MNIST, but it is useful to demonstrate scenarios where your hyperparameter\n",
    "search space affects the model complexity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2c21aa0",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConvNet(nn.Module):\n",
    "    def __init__(self, layer_size=192):\n",
    "        super(ConvNet, self).__init__()\n",
    "        self.layer_size = layer_size\n",
    "        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)\n",
    "        self.fc = nn.Linear(192, self.layer_size)\n",
    "        self.out = nn.Linear(self.layer_size, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(F.max_pool2d(self.conv1(x), 3))\n",
    "        x = x.view(-1, 192)\n",
    "        x = self.fc(x)\n",
    "        x = self.out(x)\n",
    "        return F.log_softmax(x, dim=1)\n",
    "\n",
    "\n",
    "def train(model, optimizer, train_loader, device=None):\n",
    "    device = device or torch.device(\"cpu\")\n",
    "    model.train()\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = F.nll_loss(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "\n",
    "def test(model, data_loader, device=None):\n",
    "    device = device or torch.device(\"cpu\")\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, (data, target) in enumerate(data_loader):\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            outputs = model(data)\n",
    "            _, predicted = torch.max(outputs.data, 1)\n",
    "            total += target.size(0)\n",
    "            correct += (predicted == target).sum().item()\n",
    "\n",
    "    return correct / total"
   ]
  },
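  {
   "cell_type": "markdown",
   "id": "e1a7c302",
   "metadata": {},
   "source": [
    "The hard-coded ``192`` in ``nn.Linear(192, ...)`` and ``x.view(-1, 192)``\n",
    "comes from the shapes a 28x28 MNIST image takes as it passes through\n",
    "``conv1`` and ``F.max_pool2d(..., 3)``. The sketch below redoes this\n",
    "arithmetic with the usual output-size formula for unpadded, undilated\n",
    "windows. ``out_size`` is a hypothetical helper, and the calculation relies on PyTorch's\n",
    "pooling stride defaulting to the kernel size.\n",
    "\n",
    "```python\n",
    "def out_size(size, kernel, stride=1):\n",
    "    \"\"\"Output length along one axis for an unpadded, undilated conv/pool window.\"\"\"\n",
    "    return (size - kernel) // stride + 1\n",
    "\n",
    "\n",
    "side = 28                                  # MNIST images are 28x28\n",
    "side = out_size(side, kernel=3)            # conv1 (kernel_size=3): 28 -> 26\n",
    "side = out_size(side, kernel=3, stride=3)  # max_pool2d(x, 3): 26 -> 8\n",
    "channels = 3                               # conv1 has 3 output channels\n",
    "flattened = channels * side * side\n",
    "print(side, flattened)  # 8 192\n",
    "```\n",
    "\n",
    "If you change the kernel size or the number of output channels of\n",
    "``conv1``, the ``192`` in ``ConvNet`` has to change accordingly."
   ]
  },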
  {
   "cell_type": "markdown",
   "id": "677ded46",
   "metadata": {},
   "source": [
    "## Tune trainable for model selection\n",
    "\n",
    "We'll now define our Tune trainable function. This function takes\n",
    "a ``config`` parameter containing the hyperparameters we should train\n",
    "the model on, and will start a full training run. This means it\n",
    "will take care of creating the model and optimizer and repeatedly\n",
    "call the ``train`` function to train the model. Also, this function\n",
    "will report the training progress back to Tune."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c29de4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_mnist(\n",
    "    config,\n",
    "    start_model=None,\n",
    "    checkpoint_dir=None,\n",
    "    num_epochs=10,\n",
    "    use_gpus=False,\n",
    "    data_fn=None,\n",
    "    day=0,\n",
    "):\n",
    "    # Create model\n",
    "    use_cuda = use_gpus and torch.cuda.is_available()\n",
    "    device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "    model = ConvNet(layer_size=config[\"layer_size\"]).to(device)\n",
    "\n",
    "    # Create optimizer\n",
    "    optimizer = optim.SGD(\n",
    "        model.parameters(), lr=config[\"lr\"], momentum=config[\"momentum\"]\n",
    "    )\n",
    "\n",
    "    # Load checkpoint, or load start model if no checkpoint has been\n",
    "    # passed and a start model is specified\n",
    "    load_dir = None\n",
    "    if checkpoint_dir:\n",
    "        load_dir = checkpoint_dir\n",
    "    elif start_model:\n",
    "        load_dir = start_model\n",
    "\n",
    "    if load_dir:\n",
    "        model_state, optimizer_state = torch.load(os.path.join(load_dir, \"checkpoint\"))\n",
    "        model.load_state_dict(model_state)\n",
    "        optimizer.load_state_dict(optimizer_state)\n",
    "\n",
    "    # Get full training datasets\n",
    "    train_dataset, validation_dataset = data_fn(day=day)\n",
    "\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        train_dataset, batch_size=config[\"batch_size\"], shuffle=True\n",
    "    )\n",
    "\n",
    "    validation_loader = torch.utils.data.DataLoader(\n",
    "        validation_dataset, batch_size=config[\"batch_size\"], shuffle=True\n",
    "    )\n",
    "\n",
    "    for i in range(num_epochs):\n",
    "        train(model, optimizer, train_loader, device)\n",
    "        acc = test(model, validation_loader, device)\n",
    "        if i == num_epochs - 1:\n",
    "            with tune.checkpoint_dir(step=i) as checkpoint_dir:\n",
    "                torch.save(\n",
    "                    (model.state_dict(), optimizer.state_dict()),\n",
    "                    os.path.join(checkpoint_dir, \"checkpoint\"),\n",
    "                )\n",
    "            tune.report(mean_accuracy=acc, done=True)\n",
    "        else:\n",
    "            tune.report(mean_accuracy=acc)"
   ]
  },
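  {
   "cell_type": "markdown",
   "id": "e1a7c303",
   "metadata": {},
   "source": [
    "Both ``train_mnist`` and ``MNISTDeployment`` exchange models through a\n",
    "single file named ``checkpoint`` that holds a ``(model_state, optimizer_state)``\n",
    "tuple. The minimal sketch below round-trips that format. It uses a tiny\n",
    "``nn.Linear`` stand-in instead of ``ConvNet`` and a temporary directory\n",
    "instead of a Tune checkpoint directory. Passing ``map_location`` to\n",
    "``torch.load`` allows a checkpoint written on a GPU to be restored\n",
    "on a CPU-only machine.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "# Tiny stand-in model; the tutorial uses ConvNet instead.\n",
    "model = nn.Linear(4, 2)\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)\n",
    "\n",
    "with tempfile.TemporaryDirectory() as ckpt_dir:\n",
    "    path = os.path.join(ckpt_dir, \"checkpoint\")\n",
    "    # Same layout as train_mnist: one file holding both state dicts.\n",
    "    torch.save((model.state_dict(), optimizer.state_dict()), path)\n",
    "\n",
    "    # Restore into fresh objects, mapping all tensors onto the CPU.\n",
    "    restored_model = nn.Linear(4, 2)\n",
    "    restored_opt = optim.SGD(restored_model.parameters(), lr=0.01, momentum=0.9)\n",
    "    model_state, optimizer_state = torch.load(path, map_location=\"cpu\")\n",
    "    restored_model.load_state_dict(model_state)\n",
    "    restored_opt.load_state_dict(optimizer_state)\n",
    "\n",
    "same_weights = all(\n",
    "    torch.equal(a, b)\n",
    "    for a, b in zip(model.state_dict().values(), restored_model.state_dict().values())\n",
    ")\n",
    "print(same_weights)  # True\n",
    "```"
   ]
  },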
  {
   "cell_type": "markdown",
   "id": "513f8db0",
   "metadata": {},
   "source": [
    "## Configuring the search space and starting Ray Tune\n",
    "\n",
    "We would like to support two modes of training the model: Training\n",
    "a model from scratch, and continuing to train a model from an\n",
    "existing one.\n",
    "\n",
    "This is our function to train a number of models with different\n",
    "hyperparameters from scratch, i.e. from all data that is available\n",
    "until the given day. Our search space can thus also contain parameters\n",
    "that affect the model complexity (such as the layer size), since it\n",
    "does not have to be compatible to an existing model."
   ]
  },
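  {
   "cell_type": "markdown",
   "id": "e1a7c304",
   "metadata": {},
   "source": [
    "The search space in ``tune_from_scratch`` mixes ``tune.choice``,\n",
    "``tune.loguniform`` and ``tune.uniform``. A log-uniform draw is uniform in\n",
    "the exponent, so a learning rate is about as likely to land in\n",
    "``[1e-4, 1e-3)`` as in ``[1e-2, 1e-1)``. The sketch below approximates these\n",
    "three distributions in plain Python for illustration only. It is not how\n",
    "Ray Tune implements them, and ``sample_config`` is a hypothetical name.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import random\n",
    "\n",
    "\n",
    "def sample_config(rng):\n",
    "    \"\"\"Plain-Python stand-in for one draw from the tune_from_scratch search space.\"\"\"\n",
    "    log_lo, log_hi = math.log(1e-4), math.log(1e-1)\n",
    "    return {\n",
    "        \"batch_size\": rng.choice([16, 32, 64]),        # like tune.choice\n",
    "        \"layer_size\": rng.choice([32, 64, 128, 192]),  # like tune.choice\n",
    "        \"lr\": math.exp(rng.uniform(log_lo, log_hi)),   # like tune.loguniform\n",
    "        \"momentum\": rng.uniform(0.1, 0.9),             # like tune.uniform\n",
    "    }\n",
    "\n",
    "\n",
    "rng = random.Random(0)\n",
    "samples = [sample_config(rng) for _ in range(1000)]\n",
    "# Share of learning rates per decade; each should be close to 1/3.\n",
    "for lo in (1e-4, 1e-3, 1e-2):\n",
    "    share = sum(lo <= s[\"lr\"] < lo * 10 for s in samples) / len(samples)\n",
    "    print(f\"lr in [{lo:g}, {lo * 10:g}): {share:.2f}\")\n",
    "```"
   ]
  },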
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82fcbf6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tune_from_scratch(num_samples=10, num_epochs=10, gpus_per_trial=0.0, day=0):\n",
    "    data_interface = MNISTDataInterface(\"~/data\", max_days=10)\n",
    "    num_examples = data_interface._get_day_slice(day)\n",
    "\n",
    "    config = {\n",
    "        \"batch_size\": tune.choice([16, 32, 64]),\n",
    "        \"layer_size\": tune.choice([32, 64, 128, 192]),\n",
    "        \"lr\": tune.loguniform(1e-4, 1e-1),\n",
    "        \"momentum\": tune.uniform(0.1, 0.9),\n",
    "    }\n",
    "\n",
    "    scheduler = ASHAScheduler(\n",
    "        metric=\"mean_accuracy\",\n",
    "        mode=\"max\",\n",
    "        max_t=num_epochs,\n",
    "        grace_period=1,\n",
    "        reduction_factor=2,\n",
    "    )\n",
    "\n",
    "    reporter = CLIReporter(\n",
    "        parameter_columns=[\"layer_size\", \"lr\", \"momentum\", \"batch_size\"],\n",
    "        metric_columns=[\"mean_accuracy\", \"training_iteration\"],\n",
    "    )\n",
    "\n",
    "    analysis = tune.run(\n",
    "        partial(\n",
    "            train_mnist,\n",
    "            start_model=None,\n",
    "            data_fn=data_interface.get_data,\n",
    "            num_epochs=num_epochs,\n",
    "            use_gpus=True if gpus_per_trial > 0 else False,\n",
    "            day=day,\n",
    "        ),\n",
    "        resources_per_trial={\"cpu\": 1, \"gpu\": gpus_per_trial},\n",
    "        config=config,\n",
    "        num_samples=num_samples,\n",
    "        scheduler=scheduler,\n",
    "        progress_reporter=reporter,\n",
    "        verbose=0,\n",
    "        name=\"tune_serve_mnist_fromscratch\",\n",
    "    )\n",
    "\n",
    "    best_trial = analysis.get_best_trial(\"mean_accuracy\", \"max\", \"last\")\n",
    "    best_accuracy = best_trial.metric_analysis[\"mean_accuracy\"][\"last\"]\n",
    "    best_trial_config = best_trial.config\n",
    "    best_checkpoint = best_trial.checkpoint.value\n",
    "\n",
    "    return best_accuracy, best_trial_config, best_checkpoint, num_examples"
   ]
  },
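  {
   "cell_type": "markdown",
   "id": "e1a7c305",
   "metadata": {},
   "source": [
    "With ``max_t=num_epochs``, ``grace_period=1`` and ``reduction_factor=2``,\n",
    "ASHA compares trials at a geometric sequence of milestones (rungs). At\n",
    "each rung, roughly speaking, a trial may continue only if its\n",
    "``mean_accuracy`` is in the top ``1 / reduction_factor`` of the results\n",
    "recorded at that rung so far. The sketch below lists the milestones\n",
    "using ASHA's geometric rung schedule, where rung ``k`` sits at\n",
    "``grace_period * reduction_factor**k`` iterations below ``max_t``.\n",
    "``asha_rungs`` is a hypothetical helper, not part of Ray Tune's API.\n",
    "\n",
    "```python\n",
    "def asha_rungs(max_t, grace_period=1, reduction_factor=2):\n",
    "    \"\"\"Iterations at which ASHA compares trials (geometric rung schedule).\"\"\"\n",
    "    rungs = []\n",
    "    t = grace_period\n",
    "    while t < max_t:\n",
    "        rungs.append(t)\n",
    "        t *= reduction_factor\n",
    "    return rungs\n",
    "\n",
    "\n",
    "print(asha_rungs(max_t=10))  # [1, 2, 4, 8]\n",
    "```\n",
    "\n",
    "Independently of the rungs, every trial stops once it reaches ``max_t``\n",
    "iterations. Here ``max_t`` equals the length of the ``num_epochs`` loop in ``train_mnist``."
   ]
  },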
  {
   "cell_type": "markdown",
   "id": "f051b634",
   "metadata": {},
   "source": [
    "To continue training from an existing model, we can use this function\n",
    "instead. It takes a starting model (a checkpoint) as a parameter and\n",
    "the old config.\n",
    "\n",
    "Note that this time the search space does _not_ contain the\n",
    "layer size parameter. Since we continue to train an existing model,\n",
    "we cannot change the layer size mid training, so we just continue\n",
    "to use the existing one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56b26451",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tune_from_existing(\n",
    "    start_model, start_config, num_samples=10, num_epochs=10, gpus_per_trial=0.0, day=0\n",
    "):\n",
    "    data_interface = MNISTDataInterface(\"/tmp/mnist_data\", max_days=10)\n",
    "    num_examples = data_interface._get_day_slice(day) - data_interface._get_day_slice(\n",
    "        day - 1\n",
    "    )\n",
    "\n",
    "    config = start_config.copy()\n",
    "    config.update(\n",
    "        {\n",
    "            \"batch_size\": tune.choice([16, 32, 64]),\n",
    "            \"lr\": tune.loguniform(1e-4, 1e-1),\n",
    "            \"momentum\": tune.uniform(0.1, 0.9),\n",
    "        }\n",
    "    )\n",
    "\n",
    "    scheduler = ASHAScheduler(\n",
    "        metric=\"mean_accuracy\",\n",
    "        mode=\"max\",\n",
    "        max_t=num_epochs,\n",
    "        grace_period=1,\n",
    "        reduction_factor=2,\n",
    "    )\n",
    "\n",
    "    reporter = CLIReporter(\n",
    "        parameter_columns=[\"lr\", \"momentum\", \"batch_size\"],\n",
    "        metric_columns=[\"mean_accuracy\", \"training_iteration\"],\n",
    "    )\n",
    "\n",
    "    analysis = tune.run(\n",
    "        partial(\n",
    "            train_mnist,\n",
    "            start_model=start_model,\n",
    "            data_fn=data_interface.get_incremental_data,\n",
    "            num_epochs=num_epochs,\n",
    "            use_gpus=True if gpus_per_trial > 0 else False,\n",
    "            day=day,\n",
    "        ),\n",
    "        resources_per_trial={\"cpu\": 1, \"gpu\": gpus_per_trial},\n",
    "        config=config,\n",
    "        num_samples=num_samples,\n",
    "        scheduler=scheduler,\n",
    "        progress_reporter=reporter,\n",
    "        verbose=0,\n",
    "        name=\"tune_serve_mnist_fromsexisting\",\n",
    "    )\n",
    "\n",
    "    best_trial = analysis.get_best_trial(\"mean_accuracy\", \"max\", \"last\")\n",
    "    best_accuracy = best_trial.metric_analysis[\"mean_accuracy\"][\"last\"]\n",
    "    best_trial_config = best_trial.config\n",
    "    best_checkpoint = best_trial.checkpoint.value\n",
    "\n",
    "    return best_accuracy, best_trial_config, best_checkpoint, num_examples"
   ]
  },
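  {
   "cell_type": "markdown",
   "id": "e1a7c306",
   "metadata": {},
   "source": [
    "``tune_from_existing`` builds its search space by copying the previous best\n",
    "config and overwriting only the tunable keys. As a result, ``layer_size`` keeps the\n",
    "value the existing checkpoint was trained with. The sketch below shows that\n",
    "merge with plain strings standing in for the Tune search-space objects. The\n",
    "values in ``start_config`` are made up for illustration.\n",
    "\n",
    "```python\n",
    "# Hypothetical best config from an earlier tune_from_scratch run.\n",
    "start_config = {\"batch_size\": 32, \"layer_size\": 128, \"lr\": 0.01, \"momentum\": 0.5}\n",
    "\n",
    "config = start_config.copy()  # shallow copy, start_config itself is not modified\n",
    "config.update(\n",
    "    {\n",
    "        # Strings stand in for tune.choice(...), tune.loguniform(...), tune.uniform(...)\n",
    "        \"batch_size\": \"choice([16, 32, 64])\",\n",
    "        \"lr\": \"loguniform(1e-4, 1e-1)\",\n",
    "        \"momentum\": \"uniform(0.1, 0.9)\",\n",
    "    }\n",
    ")\n",
    "\n",
    "print(config[\"layer_size\"])        # 128, fixed by the existing model\n",
    "print(start_config[\"batch_size\"])  # 32, the original dict is untouched\n",
    "```"
   ]
  },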
  {
   "cell_type": "markdown",
   "id": "a25629c1",
   "metadata": {},
   "source": [
    "## Serving tuned models with Ray Serve\n",
    "\n",
    "Let's now turn to the model serving part with Ray Serve. Serve allows\n",
    "you to deploy your models as multiple _deployments_. Broadly speaking,\n",
    "a deployment handles incoming requests and replies with a result. For\n",
    "instance, our MNIST deployment takes an image as input and outputs the\n",
    "digit it recognized from it. This deployment can be exposed over HTTP.\n",
    "\n",
    "First, we will define our deployment. This loads our PyTorch\n",
    "MNIST model from a checkpoint, takes an image as an input and\n",
    "outputs our digit prediction according to our trained model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0d6a4ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "@serve.deployment(name=\"mnist\", route_prefix=\"/mnist\")\n",
    "class MNISTDeployment:\n",
    "    def __init__(self, checkpoint_dir, config, metrics, use_gpu=False):\n",
    "        self.checkpoint_dir = checkpoint_dir\n",
    "        self.config = config\n",
    "        self.metrics = metrics\n",
    "\n",
    "        use_cuda = use_gpu and torch.cuda.is_available()\n",
    "        self.device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "        model = ConvNet(layer_size=self.config[\"layer_size\"]).to(self.device)\n",
    "\n",
    "        model_state, optimizer_state = torch.load(\n",
    "            os.path.join(self.checkpoint_dir, \"checkpoint\"), map_location=self.device\n",
    "        )\n",
    "        model.load_state_dict(model_state)\n",
    "\n",
    "        self.model = model\n",
    "\n",
    "    def __call__(self, flask_request):\n",
    "        images = torch.tensor(flask_request.json[\"images\"])\n",
    "        images = images.to(self.device)\n",
    "        outputs = self.model(images)\n",
    "        predicted = torch.max(outputs.data, 1)[1]\n",
    "        return {\"result\": predicted.numpy().tolist()}"
   ]
  },
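  {
   "cell_type": "markdown",
   "id": "e1a7c307",
   "metadata": {},
   "source": [
    "``MNISTDeployment`` expects a JSON body of the form ``{\"images\": [...]}``,\n",
    "where each image is a nested list of shape 1 x 28 x 28 (one channel).\n",
    "It answers with ``{\"result\": [...]}``, one predicted digit per image.\n",
    "The sketch below builds and checks such a payload using only the standard\n",
    "library. ``make_payload`` and the all-zero image are hypothetical;\n",
    "a real client would send normalized MNIST pixels instead.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "\n",
    "def make_payload(images):\n",
    "    \"\"\"Wrap a batch of 1x28x28 nested-list images in the expected JSON body.\"\"\"\n",
    "    for img in images:\n",
    "        # One channel, 28 rows, 28 values per row.\n",
    "        if len(img) != 1 or len(img[0]) != 28 or any(len(row) != 28 for row in img[0]):\n",
    "            raise ValueError(\"each image must have shape 1 x 28 x 28\")\n",
    "    return json.dumps({\"images\": images})\n",
    "\n",
    "\n",
    "blank = [[[0.0] * 28 for _ in range(28)]]  # hypothetical all-zero image\n",
    "body = make_payload([blank, blank])\n",
    "print(len(json.loads(body)[\"images\"]))  # 2\n",
    "```"
   ]
  },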
  {
   "cell_type": "markdown",
   "id": "2ba14c4a",
   "metadata": {},
   "source": [
    "We would like to have a fixed location where we store the currently\n",
    "active model. We call this directory ``model_dir``. Every time we\n",
    "would like to update our model, we copy the checkpoint of the new\n",
    "model to this directory. We then update the deployment to the new version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bba77923",
   "metadata": {},
   "outputs": [],
   "source": [
    "def serve_new_model(model_dir, checkpoint, config, metrics, day, use_gpu=False):\n",
    "    print(\"Serving checkpoint: {}\".format(checkpoint))\n",
    "\n",
    "    checkpoint_path = _move_checkpoint_to_model_dir(\n",
    "        model_dir, checkpoint, config, metrics\n",
    "    )\n",
    "\n",
    "    serve.start(detached=True)\n",
    "    MNISTDeployment.deploy(checkpoint_path, config, metrics, use_gpu)\n",
    "\n",
    "\n",
    "def _move_checkpoint_to_model_dir(model_dir, checkpoint, config, metrics):\n",
    "    \"\"\"Move backend checkpoint to a central `model_dir` on the head node.\n",
    "    If you would like to run Serve on multiple nodes, you might want to\n",
    "    move the checkpoint to a shared storage, like Amazon S3, instead.\"\"\"\n",
    "    os.makedirs(model_dir, 0o755, exist_ok=True)\n",
    "\n",
    "    checkpoint_path = os.path.join(model_dir, \"checkpoint\")\n",
    "    meta_path = os.path.join(model_dir, \"meta.json\")\n",
    "\n",
    "    if os.path.exists(checkpoint_path):\n",
    "        shutil.rmtree(checkpoint_path)\n",
    "\n",
    "    shutil.copytree(checkpoint, checkpoint_path)\n",
    "\n",
    "    with open(meta_path, \"wt\") as fp:\n",
    "        json.dump(dict(config=config, metrics=metrics), fp)\n",
    "\n",
    "    return checkpoint_path"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f779c7bc",
   "metadata": {},
   "source": [
    "Since we would like to continue training from the current existing\n",
    "model, we introduce an utility function that fetches the currently\n",
    "served checkpoint as well as the hyperparameter config and achieved\n",
    "accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "005f2787",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_current_model(model_dir):\n",
    "    checkpoint_path = os.path.join(model_dir, \"checkpoint\")\n",
    "    meta_path = os.path.join(model_dir, \"meta.json\")\n",
    "\n",
    "    if not os.path.exists(checkpoint_path) or not os.path.exists(meta_path):\n",
    "        return None, None, None\n",
    "\n",
    "    with open(meta_path, \"rt\") as fp:\n",
    "        meta = json.load(fp)\n",
    "\n",
    "    return checkpoint_path, meta[\"config\"], meta[\"metrics\"]"
   ]
  },
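  {
   "cell_type": "markdown",
   "id": "e1a7c308",
   "metadata": {},
   "source": [
    "``_move_checkpoint_to_model_dir`` and ``get_current_model`` together define\n",
    "a small on-disk contract. ``model_dir/checkpoint/`` holds a copy of the Tune\n",
    "checkpoint directory, and ``model_dir/meta.json`` stores the config and\n",
    "metrics. The sketch below replays the same copy-then-read steps in a\n",
    "temporary directory with a fake checkpoint file, so it runs without Ray or\n",
    "PyTorch. All paths and values are made up.\n",
    "\n",
    "```python\n",
    "import json\n",
    "import os\n",
    "import shutil\n",
    "import tempfile\n",
    "\n",
    "with tempfile.TemporaryDirectory() as root:\n",
    "    # Fake Tune checkpoint directory containing a \"checkpoint\" file.\n",
    "    tune_ckpt = os.path.join(root, \"checkpoint_000000\")\n",
    "    os.makedirs(tune_ckpt)\n",
    "    with open(os.path.join(tune_ckpt, \"checkpoint\"), \"wb\") as fp:\n",
    "        fp.write(b\"weights\")\n",
    "\n",
    "    # Same steps as _move_checkpoint_to_model_dir.\n",
    "    model_dir = os.path.join(root, \"mnist_tune_serve\")\n",
    "    os.makedirs(model_dir, 0o755, exist_ok=True)\n",
    "    checkpoint_path = os.path.join(model_dir, \"checkpoint\")\n",
    "    if os.path.exists(checkpoint_path):\n",
    "        shutil.rmtree(checkpoint_path)\n",
    "    shutil.copytree(tune_ckpt, checkpoint_path)\n",
    "    with open(os.path.join(model_dir, \"meta.json\"), \"wt\") as fp:\n",
    "        json.dump({\"config\": {\"layer_size\": 64}, \"metrics\": 0.97}, fp)\n",
    "\n",
    "    # Same reads as get_current_model.\n",
    "    with open(os.path.join(model_dir, \"meta.json\"), \"rt\") as fp:\n",
    "        meta = json.load(fp)\n",
    "    files = sorted(os.listdir(checkpoint_path))\n",
    "\n",
    "print(files, meta[\"config\"], meta[\"metrics\"])  # ['checkpoint'] {'layer_size': 64} 0.97\n",
    "```\n",
    "\n",
    "The model file therefore ends up at ``model_dir/checkpoint/checkpoint``.\n",
    "This matches the ``os.path.join(..., \"checkpoint\")`` that both\n",
    "``MNISTDeployment`` and ``train_mnist`` (via ``start_model``) apply to the\n",
    "directory they are given."
   ]
  },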
  {
   "cell_type": "markdown",
   "id": "5c55a5d3",
   "metadata": {},
   "source": [
    "## Putting everything together\n",
    "\n",
    "Now we only need to glue this code together. This is the main\n",
    "entrypoint of the script, and we will define three methods:\n",
    "\n",
    "1. Train new model from scratch with all data\n",
    "2. Continue training from existing model with new data only\n",
    "3. Query the model with test data\n",
    "\n",
    "Internally, this will just call the ``tune_from_scratch`` and\n",
    "``tune_from_existing()`` functions.\n",
    "Both training functions will then call ``serve_new_model()`` to serve\n",
    "the newly trained or updated model."
   ]
  },
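  {
   "cell_type": "markdown",
   "id": "e1a7c309",
   "metadata": {},
   "source": [
    "One way to make the three modes explicit on the command line is an\n",
    "``argparse`` mutually exclusive group. Such a group rejects combinations such as\n",
    "``--from_scratch --from_existing``. The sketch below is a hypothetical\n",
    "alternative to the parser in the next cell, not what the script does.\n",
    "``store_true`` flags already default to ``False``, so the sketch sets no ``default=`` for them.\n",
    "\n",
    "```python\n",
    "import argparse\n",
    "\n",
    "\n",
    "def build_parser():\n",
    "    parser = argparse.ArgumentParser(description=\"MNIST Tune/Serve example\")\n",
    "    # Exactly one of the three modes must be chosen.\n",
    "    mode = parser.add_mutually_exclusive_group(required=True)\n",
    "    mode.add_argument(\"--from_scratch\", action=\"store_true\")\n",
    "    mode.add_argument(\"--from_existing\", action=\"store_true\")\n",
    "    mode.add_argument(\"--query\", type=int, metavar=\"INDEX\")\n",
    "    parser.add_argument(\"--day\", type=int, default=0)\n",
    "    parser.add_argument(\"--smoke-test\", action=\"store_true\")\n",
    "    return parser\n",
    "\n",
    "\n",
    "args = build_parser().parse_args([\"--from_existing\", \"--day\", \"3\"])\n",
    "print(args.from_scratch, args.from_existing, args.day, args.query)  # False True 3 None\n",
    "```"
   ]
  },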
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "053bbbfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The query function will send a HTTP request to Serve with some\n",
    "# test data obtained from the MNIST dataset.\n",
    "if __name__ == \"__main__\":\n",
    "    \"\"\"\n",
    "    This script offers training a new model from scratch with all\n",
    "    available data, or continuing to train an existing model\n",
    "    with newly available data.\n",
    "\n",
    "    For instance, we might get new data every day. Every Sunday, we\n",
    "    would like to train a new model from scratch.\n",
    "\n",
    "    Naturally, we would like to use hyperparameter optimization to\n",
    "    find the best model for out data.\n",
    "\n",
    "    First, we might train a model with all data available at this day:\n",
    "\n",
    "    ```{code-block} bash\n",
    "    python tune-serve-integration-mnist.py --from_scratch --day 0\n",
    "    ```\n",
    "\n",
    "    On the coming days, we want to continue to train this model with\n",
    "    newly available data:\n",
    "\n",
    "    ```{code-block} bash\n",
    "    python tune-serve-integration-mnist.py --from_existing --day 1\n",
    "    python tune-serve-integration-mnist.py --from_existing --day 2\n",
    "    python tune-serve-integration-mnist.py --from_existing --day 3\n",
    "    python tune-serve-integration-mnist.py --from_existing --day 4\n",
    "    python tune-serve-integration-mnist.py --from_existing --day 5\n",
    "    python tune-serve-integration-mnist.py --from_existing --day 6\n",
    "    # Retrain from scratch every 7th day:\n",
    "    python tune-serve-integration-mnist.py --from_scratch --day 7\n",
    "    ```\n",
    "\n",
    "    We can also use this script to query our served model\n",
    "    with some test data:\n",
    "\n",
    "    ```{code-block} bash\n",
    "    python tune-serve-integration-mnist.py --query 6\n",
    "    Querying model with example #6. Label = 1, Response = 1, Correct = T\n",
    "    python tune-serve-integration-mnist.py --query 28\n",
    "    Querying model with example #28. Label = 2, Response = 7, Correct = F\n",
    "    ```\n",
    "\n",
    "    \"\"\"\n",
    "    parser = argparse.ArgumentParser(description=\"MNIST Tune/Serve example\")\n",
    "    parser.add_argument(\"--model_dir\", type=str, default=\"~/mnist_tune_serve\")\n",
    "\n",
    "    parser.add_argument(\n",
    "        \"--from_scratch\",\n",
    "        action=\"store_true\",\n",
    "        help=\"Train and select best model from scratch\",\n",
    "        default=True,\n",
    "    )\n",
    "\n",
    "    parser.add_argument(\n",
    "        \"--from_existing\",\n",
    "        action=\"store_true\",\n",
    "        help=\"Train and select best model from existing model\",\n",
    "        default=False,\n",
    "    )\n",
    "\n",
    "    parser.add_argument(\n",
    "        \"--day\",\n",
    "        help=\"Indicate the day to simulate the amount of data available to us\",\n",
    "        type=int,\n",
    "        default=0,\n",
    "    )\n",
    "\n",
    "    parser.add_argument(\n",
    "        \"--query\", help=\"Query endpoint with example\", type=int, default=-1\n",
    "    )\n",
    "\n",
    "    parser.add_argument(\n",
    "        \"--smoke-test\",\n",
    "        action=\"store_true\",\n",
    "        help=\"Finish quickly for testing\",\n",
    "        default=True,\n",
    "    )\n",
    "\n",
    "    args = parser.parse_args()\n",
    "\n",
    "    if args.smoke_test:\n",
    "        ray.init(num_cpus=3, namespace=\"tune-serve-integration\")\n",
    "    else:\n",
    "        ray.init(namespace=\"tune-serve-integration\")\n",
    "\n",
    "    model_dir = os.path.expanduser(args.model_dir)\n",
    "\n",
    "    if args.query >= 0:\n",
    "        import requests\n",
    "\n",
    "        dataset = MNISTDataInterface(\"/tmp/mnist_data\", max_days=0).dataset\n",
    "        data = dataset[args.query]\n",
    "        label = data[1]\n",
    "\n",
    "        # Query our model\n",
    "        response = requests.post(\n",
    "            \"http://localhost:8000/mnist\", json={\"images\": [data[0].numpy().tolist()]}\n",
    "        )\n",
    "\n",
    "        try:\n",
    "            pred = response.json()[\"result\"][0]\n",
    "        except:  # noqa: E722\n",
    "            pred = -1\n",
    "\n",
    "        print(\n",
    "            \"Querying model with example #{}. \"\n",
    "            \"Label = {}, Response = {}, Correct = {}\".format(\n",
    "                args.query, label, pred, label == pred\n",
    "            )\n",
    "        )\n",
    "        sys.exit(0)\n",
    "\n",
    "    gpus_per_trial = 0.5 if not args.smoke_test else 0.0\n",
    "    serve_gpu = True if gpus_per_trial > 0 else False\n",
    "    num_samples = 8 if not args.smoke_test else 1\n",
    "    num_epochs = 10 if not args.smoke_test else 1\n",
    "\n",
    "    if args.from_scratch:  # train everyday from scratch\n",
    "        print(\"Start training job from scratch on day {}.\".format(args.day))\n",
    "        acc, config, best_checkpoint, num_examples = tune_from_scratch(\n",
    "            num_samples, num_epochs, gpus_per_trial, day=args.day\n",
    "        )\n",
    "        print(\n",
    "            \"Trained day {} from scratch on {} samples. \"\n",
    "            \"Best accuracy: {:.4f}. Best config: {}\".format(\n",
    "                args.day, num_examples, acc, config\n",
    "            )\n",
    "        )\n",
    "        serve_new_model(\n",
    "            model_dir, best_checkpoint, config, acc, args.day, use_gpu=serve_gpu\n",
    "        )\n",
    "\n",
    "    if args.from_existing:\n",
    "        old_checkpoint, old_config, old_acc = get_current_model(model_dir)\n",
    "        if not old_checkpoint or not old_config or not old_acc:\n",
    "            print(\"No existing model found. Train one with --from_scratch \" \"first.\")\n",
    "            sys.exit(1)\n",
    "        acc, config, best_checkpoint, num_examples = tune_from_existing(\n",
    "            old_checkpoint,\n",
    "            old_config,\n",
    "            num_samples,\n",
    "            num_epochs,\n",
    "            gpus_per_trial,\n",
    "            day=args.day,\n",
    "        )\n",
    "        print(\n",
    "            \"Trained day {} from existing on {} samples. \"\n",
    "            \"Best accuracy: {:.4f}. Best config: {}\".format(\n",
    "                args.day, num_examples, acc, config\n",
    "            )\n",
    "        )\n",
    "        serve_new_model(\n",
    "            model_dir, best_checkpoint, config, acc, args.day, use_gpu=serve_gpu\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c8be26a",
   "metadata": {},
   "source": [
    "That's it! We now have an end-to-end workflow to train and update a\n",
    "model every day with newly arrived data. Every week we might retrain\n",
    "the whole model. At every point in time we make sure to serve the\n",
    "model that achieved the best validation set accuracy.\n",
    "\n",
    "There are some ways we might extend this example. For instance, right\n",
    "now we only serve the latest trained model. We could  also choose to\n",
    "route only a certain percentage of users to the new model, maybe to\n",
    "see if the new model really does it's job right. These kind of\n",
    "deployments are called canary deployments.\n",
    "These kind of deployments would also require us to keep more than one\n",
    "model in our ``model_dir`` - which should be quite easy: We could just\n",
    "create subdirectories for each training day.\n",
    "\n",
    "Still, this example should show you how easy it is to integrate the\n",
    "Ray libraries Ray Tune and Ray Serve in your workflow. While both tools\n",
    "also work independently of each other, they complement each other\n",
    "nicely and support a large number of use cases."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "orphan": true
 },
 "nbformat": 4,
 "nbformat_minor": 5
}